#!/usr/bin/env python

#
# Copyright (c) 2012 Politecnico di Torino, Italy.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
# Written by Simone Basso and Matteo Avalle.
#

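''' Generates C code for elliptic curve arithmetic in Lopez-Dahab
    coordinates from the three operand codes available at
    <http://www.hyperelliptic.org/EFD/g12o/>.
    The generated code uses the operations defined as macros in the
    file crypto/ec/ec2_lopezdahab.c, e.g. LOPEZDAHAB_SUM.
    Run with ``-d`` to print the ``struct lopezdahab`` definition and
    the register defines instead.  '''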

import os.path
import sys
import urllib

# When True, generate printf/BN_print_fp statements after every macro so
# that the generated C code traces its operands and results at run time.
DEBUG = False

# The Lopez-Dahab code uses an array of BIGNUM named ``regs``, i.e.
# "registers"; these are the only names allowed in the three operand codes.
# The set is built with generator expressions rather than module-level
# ``for`` statements so that no loop variables leak into the module.
REGISTERS = set(chr(c) for c in range(ord('A'), ord('J') + 1))  # A ... J
REGISTERS.update('t%d' % i for i in range(16))                  # t0 ... t15
REGISTERS.update('%s%s' % (p, i)                                 # X1 ... Z3
                 for p in ('X', 'Y', 'Z') for i in ('1', '2', '3'))
# Curve params
REGISTERS.update(('a2', 'a6'))

# The place from where we get three operand codes
URI = 'http://www.hyperelliptic.org/EFD/'

def print_defines():
    ''' Write the ``struct lopezdahab`` definition and register defines.

        The structure is generated, not hand-written, because the length
        of its ``regs`` BIGNUM array equals the number of REGISTERS.  Each
        register name, in sorted order, is then bound to one slot of that
        array through a ``ld_<name>`` define, which is how the code built
        by process() refers to the names used in the three operand codes.
        Everything is written to stdout. '''

    header = [
        '/*\n',
        ' * Definitions autogenerated by: %s\n' % sys.argv[0],
        ' */\n',
        '\n',
        'struct lopezdahab {\n',
        '\tBN_CTX\t\t*ctx;\n',
        '\tconst EC_GROUP\t*group;\n',
        '\tBIGNUM\t\tregs[%d];\n' % len(REGISTERS),
        '\tunsigned int\tflags;\n',
        '};\n',
        '\n',
    ]

    # The slot index is the register's position in sorted order.
    defines = ['#define\t\tld_%s\tregs[%i]\n' % (reg, slot)
               for slot, reg in enumerate(sorted(REGISTERS))]

    footer = [
        '\n',
        '/* End autogenerated definitions */\n',
        '\n',
    ]

    sys.stdout.writelines(header + defines + footer)

# Binary operators understood in three operand codes, in the order in which
# they are tried, and the macro of crypto/ec/ec2_lopezdahab.c for each one.
_BINARY_OPS = (('*', 'LOPEZDAHAB_MUL'), ('+', 'LOPEZDAHAB_SUM'))

def _check_register(name, role):
    ''' Raise RuntimeError unless ``name`` is one of REGISTERS; ``role``
        (e.g. "left", "operand") only appears in the error message. '''
    if name not in REGISTERS:
        raise RuntimeError('Invalid %s: %r' % (role, name))

def _split_operands(expr, symbol):
    ''' Split ``expr`` on ``symbol`` and return the two stripped operands.

        Raises ValueError unless there are exactly two operands. '''
    operands = [operand.strip() for operand in expr.split(symbol)]
    if len(operands) != 2:
        raise ValueError('Invalid operation: %s' % expr)
    return operands[0], operands[1]

def _write_macro(macro, *names):
    ''' Write ``\\tmacro(&ld->ld_<name>, ...);`` for the given registers. '''
    args = ', '.join('&ld->ld_%s' % name for name in names)
    sys.stdout.write('\t%s(%s);\n' % (macro, args))

def _write_debug_binary(left, operand1, symbol, operand2):
    ''' Write C statements that print "left = op1 <symbol> op2 = " and then
        the values of op1 and op2 and the result stored in left. '''
    write = sys.stdout.write
    write('\tprintf("%s = %s %s %s = ");\n' %
          (left, operand1, symbol, operand2))
    write('\tBN_print_fp(stdout, &ld->ld_%s);\n' % operand1)
    write('\tprintf(" %s ");\n' % symbol)
    write('\tBN_print_fp(stdout, &ld->ld_%s);\n' % operand2)
    write('\tprintf(" = ");\n')
    write('\tBN_print_fp(stdout, &ld->ld_%s);\n' % left)
    write('\tprintf("\\n");\n')

def _write_debug_square(left, operand):
    ''' Write C statements that print "left = (op)^2 = (" and then the
        value of op and the result stored in left. '''
    write = sys.stdout.write
    write('\tprintf("%s = (%s)^2 = (");\n' % (left, operand))
    write('\tBN_print_fp(stdout, &ld->ld_%s);\n' % operand)
    write('\tprintf(")^2 = ");\n')
    write('\tBN_print_fp(stdout, &ld->ld_%s);\n' % left)
    write('\tprintf("\\n");\n')

def _translate_line(line):
    ''' Write the macro (and, when DEBUG is True, the tracing statements)
        for one non-empty ``left = expr`` line of a three operand code.

        Raises ValueError for a malformed line or unsupported operation
        and RuntimeError for an unknown register name.  Everything is
        validated before anything is written for the line. '''
    left, equals, right = line.partition('=')
    if not equals or '=' in right:
        raise ValueError('Invalid line: %s' % line)
    left = left.strip()
    right = right.strip()
    _check_register(left, 'left')

    for symbol, macro in _BINARY_OPS:
        if symbol in right:
            operand1, operand2 = _split_operands(right, symbol)
            _check_register(operand1, 'operand')
            _check_register(operand2, 'operand')
            _write_macro(macro, left, operand1, operand2)
            if DEBUG:
                _write_debug_binary(left, operand1, symbol, operand2)
            return

    if right.endswith('^2'):
        operand = right[:-len('^2')].strip()
        _check_register(operand, 'operand')
        _write_macro('LOPEZDAHAB_SQUARE', left, operand)
        if DEBUG:
            _write_debug_square(left, operand)
        return

    raise ValueError('Invalid operation: %s' % right)

def process(uri):
    ''' Process three operand code

        Fetch the three operand code at ``uri`` and write to stdout a C
        function ``static int lopezdahab_<name>(struct lopezdahab *ld)``.
        <name> is the basename of ``uri`` without ``.op3`` and with dashes
        turned into underscores.  Each ``left = expr`` line becomes one
        LOPEZDAHAB_MUL, LOPEZDAHAB_SUM or LOPEZDAHAB_SQUARE macro (the
        macros are defined by ``crypto/ec/ec2_lopezdahab.c``).  When DEBUG
        is True, each macro is followed by code that prints what happens,
        so that one can debug.  Blank lines are skipped.

        On Python 3, ``uri`` must carry a scheme (e.g. ``http://`` or
        ``file://``), as required by urllib.request.urlopen.

        Raises RuntimeError when a line names an unknown register and
        ValueError when a line is malformed. '''

    try:
        from urllib.request import urlopen  # Python 3
    except ImportError:
        from urllib import urlopen  # Python 2

    fname = os.path.basename(uri).replace('.op3', '').replace('-', '_')
    filep = urlopen(uri)
    try:
        sys.stdout.write('/* %s (%s) */\n' % (os.path.basename(uri), URI))
        sys.stdout.write('static int\n')
        sys.stdout.write('lopezdahab_%s(struct lopezdahab *ld)\n' % fname)
        sys.stdout.write('{\n')

        for line in filep:
            if not isinstance(line, str):
                # Python 3 responses yield bytes; Python 2 yields str.
                line = line.decode('utf-8')
            line = line.strip()
            if not line:
                continue
            _translate_line(line)

        sys.stdout.write('\n')
        sys.stdout.write('\treturn (1);\n')
        sys.stdout.write('}\n\n')
    finally:
        # Release the connection even if a line fails to translate.
        filep.close()

def main():
    ''' main function

        Parse the command line and write the generated C code to stdout.
        ``-d`` as the only argument prints the ``struct lopezdahab``
        definition and register defines via print_defines(); otherwise
        every argument is the URI of a three operand code translated by
        process().  Exits with a usage message on stderr when no argument
        is given or when ``-d`` is combined with other arguments. '''

    prog = sys.argv[0]
    # sys.exit() appends its own newline, so the message has none.
    usage = 'Usage: %s -d\n       %s uri ...' % (prog, prog)
    args = sys.argv[1:]

    if not args:
        sys.exit(usage)

    if '-d' in args:
        # -d selects a different kind of output; without this check a
        # mixed command line would make process() try to open "-d".
        if len(args) != 1:
            sys.exit(usage)
        print_defines()
        return

    sys.stdout.write('/*\n')
    sys.stdout.write(' * Code autogenerated by: %s\n' % prog)
    sys.stdout.write(' */\n')
    sys.stdout.write('\n')

    for uri in args:
        process(uri)

    sys.stdout.write('/* End autogenerated code */\n')
    sys.stdout.write('\n')

# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
